package com.suqiu.qqrobot.utils;

import com.unfbx.chatgpt.OpenAiClient;
import com.unfbx.chatgpt.entity.chat.ChatCompletion;
import com.unfbx.chatgpt.entity.chat.ChatCompletionResponse;
import com.unfbx.chatgpt.entity.chat.Message;
import com.unfbx.chatgpt.entity.completions.Completion;
import com.unfbx.chatgpt.entity.completions.CompletionResponse;
import com.unfbx.chatgpt.entity.models.Model;
import com.unfbx.chatgpt.interceptor.OpenAILogger;
import com.unfbx.chatgpt.interceptor.OpenAiResponseInterceptor;
import okhttp3.OkHttpClient;
import okhttp3.logging.HttpLoggingInterceptor;
import org.apache.http.HttpEntity;
import org.apache.http.HttpResponse;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ContentType;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.impl.conn.PoolingHttpClientConnectionManager;
import org.apache.http.util.EntityUtils;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Scanner;
import java.util.concurrent.TimeUnit;

/**
 * HTTP utility class.
 *
 * <p>Provides simple GET/POST helpers backed by one shared, pooled Apache HttpClient, plus a
 * helper that sends a question to the OpenAI chat-completion API.
 *
 * @author suqiu
 * @since 2022/09/15 12:11
 */
public class HttpUtil {
    /** Connect, connection-request and socket timeout applied to every Apache HttpClient request. */
    private static final int TIMEOUT_MILLIS = 30000;

    /** Default request configuration for {@link #HTTP_CLIENT}; declared before the static block that uses it. */
    private static final RequestConfig REQUEST_CONFIG = RequestConfig.custom()
            .setConnectTimeout(TIMEOUT_MILLIS)
            .setConnectionRequestTimeout(TIMEOUT_MILLIS)
            .setSocketTimeout(TIMEOUT_MILLIS)
            .build();

    /** Shared pooled client used by both {@link #get(String)} and {@link #post(String, String)}. */
    private static final CloseableHttpClient HTTP_CLIENT;

    /**
     * HTTP proxy used to reach the OpenAI API. It is required when calling from mainland China and
     * not needed on overseas servers. With a local VPN client this is usually the local machine on
     * port 7890. Override with the {@code openai.proxy.host} / {@code openai.proxy.port} system
     * properties.
     */
    private static final String PROXY_HOST = System.getProperty("openai.proxy.host", "127.0.0.1");
    private static final int PROXY_PORT = Integer.getInteger("openai.proxy.port", 7890);

    /** System property consulted first for the OpenAI API key. */
    private static final String API_KEY_PROPERTY = "openai.api.key";
    /** Environment variable consulted for the OpenAI API key when the system property is unset. */
    private static final String API_KEY_ENV = "OPENAI_API_KEY";

    /** Lazily created shared OpenAI client; only accessed through {@link #sharedOpenAiClient()}. */
    private static OpenAiClient openAiClientInstance;

    static {
        PoolingHttpClientConnectionManager cm = new PoolingHttpClientConnectionManager();
        cm.setMaxTotal(100);
        cm.setDefaultMaxPerRoute(50);
        HTTP_CLIENT = HttpClients.custom()
                .setConnectionManager(cm)
                .setDefaultRequestConfig(REQUEST_CONFIG)
                .build();
    }

    /**
     * Performs an HTTP GET and returns the response body.
     *
     * <p>The status code is not checked, so error pages are returned as well. The body is decoded as
     * UTF-8. It is re-assembled line by line, with every line (including the last) terminated by
     * the platform line separator.
     *
     * @param url the absolute URL to fetch
     * @return the response body, or {@code ""} if the request failed or the response had no body
     */
    public static String get(String url) {
        HttpGet httpGet = new HttpGet(url);
        httpGet.addHeader("Content-type", "application/json; charset=utf-8");
        httpGet.setHeader("Accept", "application/json");
        try (CloseableHttpResponse response = HTTP_CLIENT.execute(httpGet)) {
            HttpEntity entity = response.getEntity();
            if (entity == null) {
                return "";
            }
            StringBuilder body = new StringBuilder();
            String separator = System.lineSeparator();
            // Closing the reader closes the entity stream, which returns the connection to the pool.
            try (BufferedReader in = new BufferedReader(
                    new InputStreamReader(entity.getContent(), StandardCharsets.UTF_8))) {
                String line;
                while ((line = in.readLine()) != null) {
                    body.append(line).append(separator);
                }
            }
            return body.toString();
        } catch (IOException e) {
            e.printStackTrace();
            return "";
        }
    }

    /**
     * POSTs a JSON document and returns the response body if the server answers 200 OK.
     *
     * @param url        the absolute URL to post to
     * @param jsonString the JSON request body; sent UTF-8 encoded with
     *                   {@code Content-Type: application/json; charset=UTF-8} so Chinese text is not garbled
     * @return the response body decoded as UTF-8 on HTTP 200; {@code null} for any other status,
     *         an empty 200 response, or an I/O failure
     */
    public static String post(String url, String jsonString) {
        HttpPost httpPost = new HttpPost(url);
        httpPost.setEntity(new StringEntity(jsonString, ContentType.APPLICATION_JSON));
        try (CloseableHttpResponse response = HTTP_CLIENT.execute(httpPost)) {
            HttpEntity httpEntity = response.getEntity();
            if (response.getStatusLine().getStatusCode() == 200) {
                return httpEntity == null ? null : EntityUtils.toString(httpEntity, "utf-8");
            }
            // Drain the unused body so the pooled connection can be reused rather than discarded.
            EntityUtils.consume(httpEntity);
        } catch (IOException e) {
            e.printStackTrace();
        }
        return null;
    }

    /**
     * Sends {@code question} to the OpenAI chat-completion API as a single user message.
     *
     * <p>The API key is read from the {@code openai.api.key} system property, or else the
     * {@code OPENAI_API_KEY} environment variable.
     *
     * @param question the user's message
     * @return the content of the first choice in the response
     * @throws IllegalStateException if no API key is configured or the response contains no choices
     */
    public static String chatCompletion(String question) {
        List<Message> messages = new ArrayList<>();
        messages.add(Message.builder().role(Message.Role.USER).content(question).build());
        ChatCompletion chatCompletion = ChatCompletion.builder().messages(messages).build();
        ChatCompletionResponse response = sharedOpenAiClient().chatCompletion(chatCompletion);
        if (response == null || response.getChoices() == null || response.getChoices().isEmpty()) {
            throw new IllegalStateException("OpenAI returned no choices: " + response);
        }
        return response.getChoices().get(0).getMessage().getContent();
    }

    /**
     * Returns the shared OpenAI client, building it on first use.
     *
     * <p>OkHttp clients own thread and connection pools and are meant to be reused. If no API key is
     * configured, nothing is cached, so a later call can succeed once the key is set.
     */
    private static synchronized OpenAiClient sharedOpenAiClient() {
        if (openAiClientInstance == null) {
            openAiClientInstance = buildOpenAiClient(resolveApiKey());
        }
        return openAiClientInstance;
    }

    /**
     * Builds an OpenAI client that routes through the configured proxy.
     *
     * <p>The client logs request/response bodies via {@link OpenAILogger} and applies the
     * library's response interceptor.
     */
    private static OpenAiClient buildOpenAiClient(String apiKey) {
        Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(PROXY_HOST, PROXY_PORT));
        HttpLoggingInterceptor loggingInterceptor = new HttpLoggingInterceptor(new OpenAILogger());
        loggingInterceptor.setLevel(HttpLoggingInterceptor.Level.BODY);
        OkHttpClient okHttpClient = new OkHttpClient.Builder()
                .proxy(proxy)
                .addInterceptor(loggingInterceptor)
                .addInterceptor(new OpenAiResponseInterceptor())
                .connectTimeout(10, TimeUnit.SECONDS)
                .writeTimeout(30, TimeUnit.SECONDS)
                .readTimeout(30, TimeUnit.SECONDS)
                .build();
        return OpenAiClient.builder()
                .apiKey(Arrays.asList(apiKey))
                .okHttpClient(okHttpClient)
                .build();
    }

    /**
     * Resolves the OpenAI API key from configuration instead of source code.
     *
     * @throws IllegalStateException if neither the system property nor the environment variable is set
     */
    private static String resolveApiKey() {
        String key = System.getProperty(API_KEY_PROPERTY);
        if (key == null || key.trim().isEmpty()) {
            key = System.getenv(API_KEY_ENV);
        }
        if (key == null || key.trim().isEmpty()) {
            throw new IllegalStateException("OpenAI API key not configured: set the " + API_KEY_ENV
                    + " environment variable or the -D" + API_KEY_PROPERTY + " system property");
        }
        return key.trim();
    }

    /** Manual smoke test: asks the text-completion endpoint to write a Java hello-world and prints the response. */
    public static void main(String[] args) {
        Completion completion = Completion.builder().prompt("你会写代码吗？请使用Java写一个helloworld").build();
        CompletionResponse completions = sharedOpenAiClient().completions(completion);
        System.out.println(completions);
    }
}
